{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型参数的访问、初始化和共享\n",
    "\n",
    "在[“线性回归的简洁实现”]一节中，我们通过`init`模块来初始化模型的全部参数。我们也介绍了访问模型参数的简单方法。本节将深入讲解如何访问和初始化模型参数，以及如何在多个层之间共享同一份模型参数。\n",
    "\n",
    "我们先定义一个与上一节中相同的含单隐藏层的多层感知机。我们依然使用默认方式初始化它的参数，并做一次前向计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 10), dtype=float32, numpy=\n",
       "array([[-0.17193761, -0.0922666 , -0.13595092, -0.11826225,  0.35631424,\n",
       "         0.23411953,  0.18851405, -0.29137024, -0.07208098, -0.32998022],\n",
       "       [ 0.04340871, -0.2755222 , -0.44214192, -0.02530663,  0.6582413 ,\n",
       "         0.12312036,  0.24236704, -0.12245701, -0.3845298 , -0.5933588 ]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = tf.keras.models.Sequential()\n",
    "net.add(tf.keras.layers.Flatten())\n",
    "net.add(tf.keras.layers.Dense(256,activation=tf.nn.relu))\n",
    "net.add(tf.keras.layers.Dense(10))\n",
    "\n",
    "X = tf.random.uniform((2,20))\n",
    "Y = net(X)\n",
    "Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.1 access model parameters\n",
    "\n",
    "对于使用`Sequential`类构造的神经网络，我们可以通过weights属性来访问网络任一层的权重。回忆一下上一节中提到的`Sequential`类与`tf.keras.Model`类的继承关系。对于`Sequential`实例中含模型参数的层，我们可以通过`tf.keras.Model`类的`weights`属性来访问该层包含的所有参数。下面，访问多层感知机`net`中隐藏层的所有参数。索引0表示隐藏层为`Sequential`实例最先添加的层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Variable 'sequential/dense/kernel:0' shape=(20, 256) dtype=float32, numpy=\n",
       " array([[ 0.0481915 , -0.12631316,  0.01195522, ...,  0.02723822,\n",
       "          0.02255017, -0.12760542],\n",
       "        [ 0.0930911 , -0.01770672,  0.02268262, ...,  0.14521751,\n",
       "          0.14541116, -0.06110913],\n",
       "        [ 0.04290995, -0.11152898,  0.064651  , ..., -0.13156009,\n",
       "         -0.09439682,  0.1101175 ],\n",
       "        ...,\n",
       "        [-0.09814546,  0.00927126, -0.11610402, ...,  0.04871327,\n",
       "         -0.09315487, -0.05659153],\n",
       "        [ 0.06200621, -0.09238889,  0.04176129, ..., -0.14487112,\n",
       "          0.06835862, -0.06371358],\n",
       "        [ 0.13075373, -0.04125882,  0.04096027, ...,  0.05943713,\n",
       "         -0.11458447, -0.08245869]], dtype=float32)>,\n",
       " tensorflow.python.ops.resource_variable_ops.ResourceVariable)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.weights[0], type(net.weights[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.2 initialize params\n",
    "\n",
    "我们在[“数值稳定性和模型初始化”]一节中描述了模型的默认初始化方法：权重参数元素为[-0.07, 0.07]之间均匀分布的随机数，偏差参数则全为0。但我们经常需要使用其他方法来初始化权重。在下面的例子中，我们将权重参数初始化成均值为0、标准差为0.01的正态分布随机数，并依然将偏差参数清零。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 4.51393833e-04,  1.57772638e-02, -1.86072791e-03,\n",
       "          9.84685216e-03,  1.17519898e-02, -3.12424754e-03,\n",
       "         -9.02889855e-03, -2.59149168e-02,  6.66950271e-03,\n",
       "         -3.18642822e-03],\n",
       "        [ 8.68504588e-03,  7.08601344e-03,  1.74540130e-03,\n",
       "          1.08063575e-02, -1.17103141e-02, -3.71275260e-03,\n",
       "         -5.68187982e-03,  1.98864192e-03, -5.00063365e-03,\n",
       "         -1.03827775e-03],\n",
       "        [-9.89197753e-03, -4.18905262e-03,  1.65334251e-02,\n",
       "         -5.04891668e-03, -9.47635528e-03, -1.82326464e-03,\n",
       "          1.91288700e-06, -1.32915629e-02,  1.11705707e-02,\n",
       "         -1.58366859e-02],\n",
       "        [ 1.27096614e-02,  1.15885818e-02, -2.00654278e-04,\n",
       "         -8.13968387e-03,  3.31314979e-04,  1.26107298e-02,\n",
       "         -6.44148933e-03, -7.69741414e-03, -3.03531438e-03,\n",
       "         -1.46293696e-02],\n",
       "        [ 2.77605117e-03, -6.10801065e-03, -2.19253893e-03,\n",
       "         -3.63684585e-03,  1.28923450e-02, -2.79305899e-03,\n",
       "         -1.15096420e-02,  7.75661319e-03,  5.74074034e-03,\n",
       "         -3.94103210e-03],\n",
       "        [ 3.16805235e-05,  1.18821743e-03,  1.02334339e-02,\n",
       "         -2.42349803e-02,  3.11323069e-02, -6.91231014e-03,\n",
       "          6.53658900e-03, -6.82671787e-04, -6.15805434e-03,\n",
       "          1.38388411e-03],\n",
       "        [-5.29898272e-04,  3.57993436e-03, -2.16954947e-02,\n",
       "          1.97736789e-02, -6.09996496e-03, -3.37204291e-03,\n",
       "         -2.47666077e-03, -8.14034697e-03,  2.72408929e-02,\n",
       "         -1.02747809e-02],\n",
       "        [-7.33285118e-03,  7.44258426e-03,  5.09619154e-03,\n",
       "          6.58885157e-03, -9.93944355e-04, -1.13407392e-02,\n",
       "          1.35631058e-02, -4.52223746e-03, -4.08853963e-03,\n",
       "         -1.29554374e-02],\n",
       "        [ 1.67277008e-02, -8.95517878e-03, -5.93329640e-03,\n",
       "          8.10202397e-03,  1.16005843e-03, -7.15200265e-04,\n",
       "          1.79609824e-02, -4.64062719e-03, -9.43446811e-03,\n",
       "          5.23514627e-03],\n",
       "        [ 3.49203753e-03,  1.62120722e-03,  7.17066461e-04,\n",
       "          1.57900073e-03,  5.83383115e-03,  5.89554943e-03,\n",
       "         -6.63165329e-03,  2.77736527e-03, -3.02705122e-03,\n",
       "         -2.25389283e-03],\n",
       "        [-2.39626411e-03, -2.01081764e-02, -9.41819139e-03,\n",
       "          2.23807991e-04, -8.25551152e-03,  1.24586355e-02,\n",
       "          5.82520571e-03, -1.30321952e-02, -1.68187777e-03,\n",
       "         -4.91316617e-03],\n",
       "        [-1.83669236e-02,  5.32839389e-04,  3.08534643e-03,\n",
       "         -9.60799772e-03, -1.01758521e-02, -5.07762842e-03,\n",
       "         -9.87021625e-03,  9.36440099e-03, -6.13500131e-04,\n",
       "         -7.96126015e-03],\n",
       "        [-2.63913558e-03, -4.45986493e-03, -9.76439659e-03,\n",
       "         -6.42446242e-03,  3.94819118e-03, -1.31137511e-02,\n",
       "          3.64012923e-03,  1.19350553e-02,  9.75352339e-03,\n",
       "         -1.79445231e-03],\n",
       "        [ 1.39155835e-02,  1.77140143e-02,  1.27706314e-02,\n",
       "         -5.93537278e-03, -5.72185311e-03, -2.25463528e-02,\n",
       "         -1.54848322e-02,  2.78064894e-04, -9.24902223e-03,\n",
       "         -8.18772241e-03],\n",
       "        [-1.09455734e-03, -3.82284774e-03, -2.12653484e-02,\n",
       "          1.80865508e-02, -1.44651753e-03, -1.08752481e-03,\n",
       "         -1.12096537e-02,  5.67936013e-03, -4.58946591e-03,\n",
       "         -5.27656171e-03],\n",
       "        [ 9.37363226e-03, -1.73870027e-02, -1.54605729e-03,\n",
       "         -7.35433819e-03,  7.65736680e-03, -4.15695977e-04,\n",
       "         -1.04166539e-02,  1.26024438e-02,  2.09770631e-02,\n",
       "          6.54655322e-03],\n",
       "        [ 1.15516102e-02, -1.98801458e-02, -2.48490926e-03,\n",
       "         -9.94555838e-03, -4.13135951e-03,  1.66282095e-02,\n",
       "         -1.18775126e-02,  3.23300622e-03, -1.27977105e-02,\n",
       "         -2.80499067e-02],\n",
       "        [-6.05542585e-03, -8.90592486e-03, -1.22854451e-03,\n",
       "          1.53648220e-02,  1.61505805e-03, -4.27642325e-03,\n",
       "          3.21729970e-03,  7.03209825e-03, -2.59752199e-03,\n",
       "          4.40714508e-03],\n",
       "        [-1.11575872e-02,  1.38611933e-02, -4.55705216e-03,\n",
       "          5.11751394e-04, -1.33980699e-02,  9.65256616e-03,\n",
       "          2.50446057e-04,  2.05488931e-02,  7.39773782e-03,\n",
       "         -3.29138786e-02],\n",
       "        [-6.95713982e-03,  3.55372159e-03,  6.74101571e-03,\n",
       "          5.75583987e-03, -4.85088537e-03, -4.07720770e-04,\n",
       "         -1.26192234e-02, -6.69517671e-04, -2.13461705e-02,\n",
       "         -3.41037405e-03]], dtype=float32),\n",
       " array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
       " array([[1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.],\n",
       "        [1.]], dtype=float32),\n",
       " array([1.], dtype=float32)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Linear(tf.keras.Model):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.d1 = tf.keras.layers.Dense(\n",
    "            units=10,\n",
    "            activation=None,\n",
    "            kernel_initializer=tf.random_normal_initializer(mean=0,stddev=0.01),\n",
    "            bias_initializer=tf.zeros_initializer()\n",
    "        )\n",
    "        self.d2 = tf.keras.layers.Dense(\n",
    "            units=1,\n",
    "            activation=None,\n",
    "            kernel_initializer=tf.ones_initializer(),\n",
    "            bias_initializer=tf.ones_initializer()\n",
    "        )\n",
    "\n",
    "    def call(self, input):\n",
    "        output = self.d1(input)\n",
    "        output = self.d2(output)\n",
    "        return output\n",
    "\n",
    "net = Linear()\n",
    "net(X)\n",
    "net.get_weights()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.3 define initializer\n",
    "\n",
    "可以使用`tf.keras.initializers`类中的方法实现自定义初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'sequential_1/dense_4/kernel:0' shape=(20, 64) dtype=float32, numpy=\n",
       "array([[1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       ...,\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_init():\n",
    "    return tf.keras.initializers.Ones()\n",
    "\n",
    "model = tf.keras.models.Sequential()\n",
    "model.add(tf.keras.layers.Dense(64, kernel_initializer=my_init()))\n",
    "\n",
    "Y = model(X)\n",
    "model.weights[0]"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python [conda env:tf2]",
   "language": "python",
   "name": "conda-env-tf2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  },
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
